# encoding: utf-8

"""
@Time: 2021-07-19 20:14 
@Author: Libing Wang
@File: cfgs.py 
@description: configuration file and hyper-parameters for training
"""

import os
import torch.optim as optim
from torchvision import transforms
from dataset_scene import LmdbDataset
from net.DAN_VSFM_classifier import FeatureExtractor, CAM, DTD


# Top-level run configuration consumed by the training entry point.
global_cfgs = dict(
    state="Train",        # mirrors the datasets' "global_state" ("Train" / "Test")
    epochs=10000,
    show_interval=5,      # interval counters — exact semantics defined by the trainer
    test_interval=100,
    devices=[0],          # device ids — presumably CUDA ordinals; confirm in trainer
)


def _lmdb_split_args(split, state):
    """Build the LmdbDataset constructor kwargs for one data split.

    The train and test splits differ only in the directory name and the
    "global_state" flag, so the shared kwargs live here instead of being
    copy-pasted (previously the two blocks had to be kept in sync by hand).
    """
    return {
        "roots": os.path.join(os.getcwd(), 'lmdbdata', 'syndata', split),
        "img_h": 32,
        "img_w": 128,
        "transform": transforms.Compose([transforms.ToTensor()]),
        "global_state": state,
    }


# Dataset / DataLoader configuration for both splits.
dataset_cfgs = {
    "dataset_train": LmdbDataset,
    "dataset_train_args": _lmdb_split_args('train', "Train"),
    "dataloader_train": {
        "batch_size": 256,
        "shuffle": True,    # shuffle only the training loader
        "num_workers": 0,
    },
    "dataset_test": LmdbDataset,
    "dataset_test_args": _lmdb_split_args('test', "Test"),
    "dataloader_test": {
        "batch_size": 256,
        "shuffle": False,
        "num_workers": 0,
    },
    "case_sensitive": True,
    "dict_dir": "dict/dic_yn.txt",  # character-dictionary file path (relative to CWD)
}


# Network architecture configuration: the model is assembled from three
# sub-modules (FE -> CAM -> DTD) imported from net.DAN_VSFM_classifier.
# NOTE(review): names appear to follow the DAN paper's components — confirm.
net_cfgs = {
    "FE": FeatureExtractor,
    "FE_args": {
        # Presumably one stride pair per conv stage — TODO confirm against FeatureExtractor.
        "strides": [(1, 1), (2, 2), (1, 1), (2, 2), (1, 1), (1, 1)],
        # Matches the datasets' 1-channel 32x128 images (C, H, W).
        "input_shape": [1, 32, 128],
    },
    "CAM": CAM,
    "CAM_args": {
        "maxT": 32,     # time-step length, changed to 32
        "depth": 8,
        "num_channels": 64,
    },
    "DTD": DTD,
    "DTD_args": {
        # TODO: set this to the actual number of classes.
        "num_class": 198,   # extra 2 classes for Unknown and End_token
        "num_channels": 512,
        "dropout": 0.3,
    },

    # Optional checkpoint paths to warm-start each sub-module (None = from scratch).
    "init_state_dict_fe": None,
    "init_state_dict_cam": None,
    "init_state_dict_dtd": None,
}


# One Adadelta optimizer + MultiStepLR scheduler per sub-module
# (0 = FE, 1 = CAM, 2 = DTD).  All three used identical, copy-pasted
# hyper-parameters, so the entries are now generated in a loop: a future
# tweak cannot silently desync the three copies.
optimizer_cfgs = {}
for _idx in range(3):  # 0: FE, 1: CAM, 2: DTD
    optimizer_cfgs["optimizer_{}".format(_idx)] = optim.Adadelta
    optimizer_cfgs["optimizer_{}_args".format(_idx)] = {
        "lr": 1.0,
        "weight_decay": 0.0001,
    }
    optimizer_cfgs["optimizer_{}_scheduler".format(_idx)] = optim.lr_scheduler.MultiStepLR
    optimizer_cfgs["optimizer_{}_scheduler_args".format(_idx)] = {
        # TODO: tune these milestones for the actual training schedule.
        "milestones": [500, 1200],
        "gamma": 0.1,
    }


# Checkpoint-saving configuration.
_saving_root = os.path.join(os.getcwd(), 'logs', 'logs_classifier', 'syn_model')

saving_cfgs = {
    "saving_iter_interval": 2000,   # interval counters — semantics defined by the trainer
    "saving_epoch_interval": 1,
    "saving_path": _saving_root,
}


def make_dir(path):
    """Create *path* (including parents) if it does not already exist.

    Uses ``exist_ok=True`` instead of the original check-then-create
    pattern, which was racy: another process creating the directory
    between ``os.path.exists`` and ``os.makedirs`` raised FileExistsError.
    """
    os.makedirs(path, exist_ok=True)


def show_cfgs(s):
    """Print each key/value pair of config dict *s*, then a blank line."""
    for key, value in s.items():
        print(key, value)
    print('')


# Side effect at import time: ensure the checkpoint directory exists.
make_dir(saving_cfgs["saving_path"])
